package db_gorm

import (
	"context"
	"errors"
	"fmt"
	plumeErrors "gitee.com/lipore/plume/errors"
	plumeLogger "gitee.com/lipore/plume/logger"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
	"gorm.io/gorm/utils"
	"time"
)

// Format strings used by Logger.Trace. Every trace line has the same layout:
// the caller/error value (printed with %+v), the elapsed time in
// milliseconds, the affected-row count ("-" when unknown) and the SQL text.
// The warn and error variants are kept as separate names so call sites read
// clearly, but they share the single literal below.
const (
	traceStr     = "%+v\n[%.3fms] [rows:%v] %s"
	traceWarnStr = traceStr
	traceErrStr  = traceStr
)

// LoggerOptions configures a Logger created by NewLogger.
type LoggerOptions struct {
	// SlowThreshold is the duration above which a statement is reported as
	// slow SQL at warn level. Zero disables slow-query reporting.
	SlowThreshold time.Duration
	// IgnoreRecordNotFoundError suppresses error-level logging of
	// gorm.ErrRecordNotFound when true.
	IgnoreRecordNotFoundError bool
	// LogLevel is the plume level; it is translated to a GORM log level by
	// logLevelMap when the Logger is built.
	LogLevel plumeLogger.Level
}

// Logger adapts a plume logger to GORM's logger.Interface
// (LogMode/Info/Warn/Error/Trace).
type Logger struct {
	// writer receives every formatted log line.
	writer plumeLogger.Logger
	// level is the GORM log level that gates Info/Warn/Error/Trace output.
	level logger.LogLevel
	// ignoreRecordNotFoundError skips error logging of gorm.ErrRecordNotFound in Trace.
	ignoreRecordNotFoundError bool
	// slowThreshold marks statements as slow in Trace; zero disables the check.
	slowThreshold time.Duration
}

// logLevelMap translates a plume log level into the closest GORM log level.
// GORM has no debug level, so DEBUG collapses into Info. PANIC and FATAL
// collapse into Error. Any unrecognised level falls back to Warn.
func logLevelMap(level plumeLogger.Level) logger.LogLevel {
	var mapped logger.LogLevel
	switch level {
	case plumeLogger.DEBUG, plumeLogger.INFO:
		mapped = logger.Info
	case plumeLogger.WARN:
		mapped = logger.Warn
	case plumeLogger.ERROR, plumeLogger.PANIC, plumeLogger.FATAL:
		mapped = logger.Error
	default:
		mapped = logger.Warn
	}
	return mapped
}

// NewLogger builds a GORM logger that writes through the logger returned by
// plumeLogger.GetLogger at construction time.
//
// options supplies the slow-query threshold, the record-not-found policy and
// the plume level, which is translated with logLevelMap. A nil options is
// treated like a zero-value LoggerOptions instead of panicking. With zero
// values, slow-query reporting is disabled and record-not-found errors are
// not ignored.
func NewLogger(options *LoggerOptions) *Logger {
	if options == nil {
		// Previously this dereferenced nil and panicked; fall back to defaults.
		options = &LoggerOptions{}
	}
	return &Logger{
		writer:                    plumeLogger.GetLogger(),
		level:                     logLevelMap(options.LogLevel),
		ignoreRecordNotFoundError: options.IgnoreRecordNotFoundError,
		slowThreshold:             options.SlowThreshold,
	}
}

// LogMode returns a copy of the logger whose GORM level is replaced by level.
// The receiver itself is left unchanged, so derived loggers never affect the
// original.
func (l *Logger) LogMode(level logger.LogLevel) logger.Interface {
	derived := new(Logger)
	*derived = *l
	derived.level = level
	return derived
}

// Info formats the message with fmt.Sprintf and forwards it to the writer's
// Info method. It only writes when the configured level is Info or more verbose.
func (l *Logger) Info(_ context.Context, format string, args ...interface{}) {
	if l.level < logger.Info {
		return
	}
	msg := fmt.Sprintf(format, args...)
	l.writer.Info(msg)
}

// Warn formats the message with fmt.Sprintf and forwards it to the writer's
// Warn method. It only writes when the configured level is Warn or more verbose.
func (l *Logger) Warn(_ context.Context, format string, args ...interface{}) {
	if l.level < logger.Warn {
		return
	}
	msg := fmt.Sprintf(format, args...)
	l.writer.Warn(msg)
}

// Error formats the message with fmt.Sprintf and forwards it to the writer's
// Error method. It only writes when the configured level is Error or more verbose.
func (l *Logger) Error(_ context.Context, format string, args ...interface{}) {
	if l.level < logger.Error {
		return
	}
	msg := fmt.Sprintf(format, args...)
	l.writer.Error(msg)
}

// Trace is called by GORM after each statement and logs it according to the
// configured level. The first matching rule wins:
//   - a non-nil err is logged at error level, unless it is
//     gorm.ErrRecordNotFound and ignoreRecordNotFoundError is set;
//   - a statement slower than a non-zero slowThreshold is logged at warn level;
//   - otherwise, at Info level, every statement is logged with its caller location.
//
// fc is only invoked when a line is actually going to be written. GORM reports
// -1 affected rows when the count is unavailable; that is printed as "-".
func (l *Logger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
	if l.level <= logger.Silent {
		return
	}

	elapsed := time.Since(begin)
	elapsedMs := float64(elapsed.Nanoseconds()) / 1e6

	// sqlAndRows renders the statement, substituting "-" for an unknown (-1)
	// row count so every branch formats rows the same way.
	sqlAndRows := func() (string, interface{}) {
		sql, rows := fc()
		if rows == -1 {
			return sql, "-"
		}
		return sql, rows
	}

	switch {
	case err != nil && l.level >= logger.Error && (!errors.Is(err, gorm.ErrRecordNotFound) || !l.ignoreRecordNotFoundError):
		sql, rows := sqlAndRows()
		l.Error(ctx, traceErrStr, plumeErrors.WithCode(err, 0, ""), elapsedMs, rows, sql)
	case elapsed > l.slowThreshold && l.slowThreshold != 0 && l.level >= logger.Warn:
		sql, rows := sqlAndRows()
		slowLog := fmt.Sprintf("SLOW SQL >= %v", l.slowThreshold)
		// err is deliberately not attached, regardless of the row count.
		// Reaching this branch with level >= Warn (which implies >= Error)
		// means err is nil or an ErrRecordNotFound the caller asked to ignore.
		// Previously it was attached only when rows != -1.
		l.Warn(ctx, traceWarnStr, plumeErrors.WithCode(nil, 0, slowLog), elapsedMs, rows, sql)
	case l.level == logger.Info:
		sql, rows := sqlAndRows()
		l.Info(ctx, traceStr, utils.FileWithLineNum(), elapsedMs, rows, sql)
	}
}
